# scripts/step2_ugm.py
import argparse
import os
import json
import math
import numpy as np
import pandas as pd
from time import perf_counter

from src.present_act.gates import ThetaLadder, KappaLadder, StructuralGates, CRA
from src.present_act.lints import Lints
from src.present_act.engine import PresentActEngine, RunManifest
from src.present_act.scenes import (
    make_optics_scene,
    make_optics_roi,
    profile_from_roi,
    place_sources_from_s,
)
from src.utils.analysis import two_segment_bic, simple_regression
from scripts._util import ensure_out, write_md, load_cfg


# ------------------------- helpers -------------------------

def schedule5(w: int, L: int) -> list[int]:
    """
    Build a sorted schedule of up to five distinct integer separations s.

    One entry is pinned near 0.1*L (never below 2) so the schedule always
    contains a small separation. The others are multiples of
    s_pred = sqrt(w*L); any candidate that rounds below 2 is discarded.
    If rounding collisions leave fewer than five values, up to twelve
    further multipliers (1.10, 1.15, ...) are tried, so a very small w*L
    can still produce fewer than five entries.
    """
    s_pred = math.sqrt(max(1e-9, w * L))
    chosen = {max(2, int(round(0.10 * L)))}

    # Four multiples bracketing the predicted separation.
    for factor in (0.65, 0.85, 1.00, 1.25):
        candidate = int(round(s_pred * factor))
        if candidate >= 2:
            chosen.add(candidate)

    # Top up with slightly larger multiples until five distinct values exist
    # or the retry budget (12 attempts) is exhausted.
    for attempt in range(12):
        if len(chosen) >= 5:
            break
        candidate = int(round(s_pred * (1.10 + 0.05 * attempt)))
        if candidate >= 2:
            chosen.add(candidate)

    return sorted(chosen)[:5]


def window_sum(arr: np.ndarray, center: int, width: int) -> int:
    """
    Return the integer sum of `arr` over a window of `width` cells at `center`.

    - width <= 1: the single element at `center`, with the index clamped into
      the valid range of `arr`.
    - width  > 1: the half-open slice [center - width//2, center + width//2 + width%2),
      clipped to the array bounds. For an even width the window therefore
      holds one more cell to the left of `center` than to the right.

    Returns 0 when the clipped window is empty, including when `arr` itself
    is empty (both width regimes behave the same way for empty input).
    """
    n = len(arr)
    if n == 0:
        # Nothing to sum; without this guard the width<=1 path would index arr[0].
        return 0
    if width <= 1:
        return int(arr[min(n - 1, max(0, center))])
    half = width // 2
    lo = max(0, center - half)
    hi = min(n, center + half + (width % 2))
    if lo >= hi:
        return 0
    return int(arr[lo:hi].sum())


def build_two_wedge_mask(W: int, H: int, xL: int, xR: int, y_src: int, y_mid: int,
                         base_hw: int, mid_hw: int) -> np.ndarray:
    """
    Build an (H, W) boolean mask containing two wedges centred on xL and xR.

    Each wedge is a corridor whose halfwidth is interpolated linearly from
    `base_hw` on the source row to `mid_hw` on the midline row. Rows outside
    the span between the two rows use `base_hw`.

    - W, H    : mask width and height in cells
    - xL, xR  : x centres of the left and right wedge
    - y_src   : source row (clamped into [0, H-1])
    - y_mid   : midline row (clamped into [0, H-1])
    - base_hw : halfwidth at the source row
    - mid_hw  : halfwidth at the midline row

    If the two rows coincide after clamping, the source row is moved one row
    up. When that is impossible (both rows are row 0) the span has zero
    length, and the single shared row is given the midline halfwidth instead
    of dividing by zero.
    """
    mask = np.zeros((H, W), dtype=bool)
    y_src = max(0, min(H - 1, y_src))
    y_mid = max(0, min(H - 1, y_mid))
    if y_src == y_mid:
        y_src = max(0, y_mid - 1)
    total = abs(y_mid - y_src)
    y_top, y_bot = min(y_src, y_mid), max(y_src, y_mid)

    for y in range(H):
        if y_top <= y <= y_bot:
            # total == 0 only when source and midline share row 0; treat that
            # row as the midline (t = 1) rather than raising ZeroDivisionError.
            t = abs(y - y_src) / total if total else 1.0
        else:
            t = 0.0
        hw = int(round((1.0 - t) * base_hw + t * mid_hw))
        for cx in (xL, xR):
            lo = max(0, cx - hw)
            hi = min(W, cx + hw + 1)
            if lo < hi:
                mask[y, lo:hi] = True
    return mask


def apply_neck_row(mask: np.ndarray, x0: int, x1: int, y_mid: int,
                   mode: str, neck_center: int, neck_hw: int,
                   twin_centers: tuple[int, int] | None = None) -> None:
    """
    At the ROI midline row y_mid, restrict allowed x to either:
      - a single neck centered at neck_center with halfwidth neck_hw (mode='single'), or
      - two necks centered at twin_centers with halfwidth neck_hw (mode='double').
    """
    # clear ROI row then re-allow the neck(s)
    mask[y_mid, x0:x1] = False
    if mode == "single":
        lo = max(x0, neck_center - neck_hw)
        hi = min(x1, neck_center + neck_hw + 1)
        if lo < hi:
            mask[y_mid, lo:hi] = True
    else:
        assert twin_centers is not None
        cL, cR = twin_centers
        loL = max(x0, cL - neck_hw); hiL = min(x1, cL + neck_hw + 1)
        loR = max(x0, cR - neck_hw); hiR = min(x1, cR + neck_hw + 1)
        if loL < hiL: mask[y_mid, loL:hiL] = True
        if loR < hiR: mask[y_mid, loR:hiR] = True


# ------------------------- core runner -------------------------

def run_optics_ugm(L: int, w: int, s_list: list[int], seeds: list[int], shots: int) -> pd.DataFrame:
    """
    Run the two-source optics scene for every (s, seed) pair and tabulate PSI.

    For each separation `s` in `s_list` the function:
      1. places two sources with `place_sources_from_s` on row `scene.H // 4`;
      2. derives three Theta bins from the vertical source-to-midline distance
         (distance + 5 / + 7 / + 9) so the ladder has headroom beyond that distance;
      3. builds a two-wedge boolean mask and restricts the ROI midline row to
         - one neck at the ROI centre   when s <= s_pred = sqrt(w*L), or
         - two necks at the source x's  when s >  s_pred.
    For each seed it then builds a fresh manifest, screen and engine, runs
    `shots` propose/accept rounds, and increments the screen at every accepted
    (x, y) that falls inside the ROI bounding box.

    Readout per (s, seed): integer window sums SL, SR, SM of the ROI profile
    around the left source, right source and their midpoint, then
        d_out = (SL + SR) - 2*SM
        PSI   = (d_out * 1024) // (SM + 1)      (integer floor division)
    The window width is max(15, round(0.5 * sqrt(w*L))).

    Side effects: writes one profile CSV per (s, seed) into the directory
    returned by `ensure_out("out", "step2", "profiles")` and prints progress.

    Returns a DataFrame with one row per (s, seed) holding the sums, PSI,
    theta bins, window width, commit count and elapsed seconds.

    NOTE(review): the original author describes the engine as accepting one
    candidate per shot with RNG used only for ties; that behaviour lives in
    PresentActEngine and is not verifiable from this file.
    """
    scene = make_optics_scene(L, w=w)
    rows = []

    # PSI window width ≈ 0.5*sqrt(w*L), never narrower than 15 cells
    win   = max(15, int(round(0.50 * math.sqrt(max(1.0, w * L)))))
    # Fixed-point scale applied to d_out before the integer division in PSI
    SCALE = 1024

    # ROI geometry (bbox unpacked as x0, y0, x1, y1; used below as half-open ranges)
    x0, y0, x1, y1 = scene.roi_bbox
    y_mid = (y0 + y1) // 2
    cx_roi = (x1 - x0) // 2          # ROI center in ROI x-index
    cx_abs = x0 + cx_roi             # absolute x of ROI center
    # Separation threshold that switches between the single- and twin-neck rule
    s_pred = math.sqrt(max(1e-9, w * L))

    # Wedge halfwidths (boolean feasibility), both proportional to L
    base_hw = max(2, int(round(0.02 * L)))               # at source row
    mid_hw  = max(base_hw + 1, int(round(0.08 * L)))     # at midline, always wider than base

    # Neck halfwidths scale with sqrt(w*L); the twin necks are half as wide
    single_neck_hw = max(2, int(round(0.06 * math.sqrt(max(1.0, w * L)))))
    twin_neck_hw   = max(2, int(round(0.03 * math.sqrt(max(1.0, w * L)))))

    # Output directory for the per-run debug profiles
    # (presumably ensure_out creates it and returns its path — confirm in scripts._util)
    prof_dir = ensure_out("out", "step2", "profiles")
    print(f"[UGM] L={L} s_list={s_list} (win={win}, neck_hw={single_neck_hw}/{twin_neck_hw})", flush=True)

    for s in s_list:
        # Two sources placed for separation s on row H//4
        # (presumably symmetric about the ROI centre — see place_sources_from_s)
        (xL, y_s), (xR, y_s2) = place_sources_from_s(scene, s=s, y_row=scene.H // 4)

        # Dynamic Theta: vertical distance from the left source to the midline
        # plus fixed headroom of 5 / 7 / 9
        dist = abs(y_s - y_mid)
        theta_bins = [max(1, dist + 5), dist + 7, dist + 9]

        # Build the wedge mask for this separation
        wedge_mask = build_two_wedge_mask(scene.W, scene.H, xL, xR, y_s, y_mid,
                                          base_hw=base_hw, mid_hw=mid_hw)

        # Source x positions expressed as ROI-relative column indices
        idxL = int(xL - x0)
        idxR = int(xR - x0)
        # Shared-neck rule at the ROI midline: one central neck at or below
        # s_pred, two necks under the sources above it
        if s <= s_pred:
            apply_neck_row(wedge_mask, x0, x1, y_mid, mode="single",
                           neck_center=cx_abs, neck_hw=single_neck_hw)
        else:
            # neck_center is a placeholder here: apply_neck_row ignores it
            # outside 'single' mode. x0 + idxL / x0 + idxR equal xL / xR.
            apply_neck_row(wedge_mask, x0, x1, y_mid, mode="double",
                           neck_center=0, neck_hw=twin_neck_hw,
                           twin_centers=(x0 + idxL, x0 + idxR))

        # One independent run per seed
        for seed in seeds:
            t0 = perf_counter()
            man = RunManifest(
                theta=ThetaLadder(theta_bins),
                kappa=KappaLadder([0, 1, 2, 3]),
                structural=StructuralGates(),
                cra=CRA(True),
                lints=Lints(),
                seed=seed,
            )
            # Fresh accumulation screen and engine for this seed
            screen = make_optics_roi(scene)
            eng = PresentActEngine(scene, man)

            commits = 0
            src_pair = [(xL, y_s), (xR, y_s2)]
            for i in range(shots):
                # Re-install the mask before every proposal.
                # NOTE(review): loop-invariant from this file's point of view;
                # presumably guards against the engine replacing it — confirm
                # before hoisting.
                scene.allowed_mask = wedge_mask
                cands = eng.propose_candidates(src_pair, screen)
                acc, _ = eng.accept(cands)
                if acc is not None:
                    x, y = acc
                    # Only hits inside the ROI bounding box are painted and counted.
                    # NOTE(review): assumes screen is indexed by absolute scene
                    # (y, x) — confirm against make_optics_roi.
                    if (x0 <= x < x1) and (y0 <= y < y1):
                        screen[y, x] += 1
                        commits += 1
                # Progress line roughly every 10% of the shots
                if (i + 1) % max(1, shots // 10) == 0:
                    print(f"[UGM] L={L} s={s} seed={seed} shots {i+1}/{shots} commits={commits}", flush=True)

            # Integer profile of the ROI
            # (presumably collapsed over y and indexed by ROI-relative x — TODO confirm)
            prof = profile_from_roi(scene, screen)
            # Save the profile for inspection
            prof_path = os.path.join(prof_dir, f"profile_L{L}_s{s}_seed{seed}.csv")
            pd.DataFrame({"x": np.arange(len(prof)), "count": prof.astype(int)}).to_csv(prof_path, index=False)

            # PSI from window sums under each source and at their midpoint
            idxM = (idxL + idxR) // 2
            SL = window_sum(prof, idxL, win)
            SR = window_sum(prof, idxR, win)
            SM = window_sum(prof, idxM, win)

            # d_out > 0 means the mass under the sources exceeds twice the midpoint mass.
            # SM + 1 keeps the divisor non-zero; // floors, including for negative d_out.
            d_out = (SL + SR) - 2 * SM
            PSI   = (d_out * SCALE) // (SM + 1)

            rows.append(
                {
                    "scene": "optics",
                    "seed": seed,
                    "L_out": L,
                    "w": w,
                    "s": s,
                    "SL": int(SL),
                    "SR": int(SR),
                    "SM": int(SM),
                    "d_out": int(d_out),
                    "PSI_scalar": int(PSI),
                    "theta_bins": theta_bins,
                    "win": int(win),
                    "commits": commits,
                    "elapsed_s": round(perf_counter() - t0, 3),
                }
            )
            print(f"[UGM] L={L} s={s} seed={seed} done commits={commits} PSI={PSI}", flush=True)

    return pd.DataFrame(rows)


def main(cfg):
    """
    Run the UGM step for every configured container size and write the results.

    Reads from `cfg`:
      - cfg["common"]["seeds"], cfg["common"]["shots"]
      - cfg["scene"]["w_inner_px"]
      - cfg["ugm"]["containers_for_ugm"] : container sizes L (must be non-empty)
      - cfg["ugm"]["schedules"]          : optional mapping L -> list of s values;
        keys may be strings or ints. A missing or empty entry falls back to
        `schedule5(w, L)`.

    Writes, under the directory returned by `ensure_out("out", "step2")`:
      optics_L{L}.csv, optics_all.csv, kink_fits_optics.csv and, when at least
      two containers yield a finite kink, scaling_optics.json.
    Also writes out/calibration_space_hinge.json and out/RESULTS_UGM.md.

    Raises ValueError if no container sizes are configured.
    """
    # Read and validate the configuration before any directory or file is touched.
    seeds = cfg["common"]["seeds"]
    shots = int(cfg["common"]["shots"])
    w = int(cfg["scene"]["w_inner_px"])
    Louts = cfg["ugm"]["containers_for_ugm"]
    # A missing or null `schedules` section means "derive every schedule".
    schedules_cfg = cfg["ugm"].get("schedules") or {}

    if not Louts:
        # Fail with a clear message instead of pandas' "No objects to concatenate".
        raise ValueError("cfg['ugm']['containers_for_ugm'] must list at least one container size")

    out2 = ensure_out("out", "step2")

    all_df = []
    for L in Louts:
        # Mapping keys may be strings ("128") or ints (128) depending on how
        # the config was written, so accept both before falling back.
        s_list = schedules_cfg.get(str(L)) or schedules_cfg.get(L) or schedule5(w, L)
        df = run_optics_ugm(L, w, s_list, seeds, shots)
        df.to_csv(os.path.join(out2, f"optics_L{L}.csv"), index=False)
        all_df.append(df)

    DF = pd.concat(all_df, ignore_index=True)
    DF.to_csv(os.path.join(out2, "optics_all.csv"), index=False)

    # Per-container kink fit of mean PSI against log(s).
    rows = []
    for L in Louts:
        D = DF[DF.L_out == L].groupby("s")["PSI_scalar"].mean().reset_index()
        x = np.log(D["s"].values.astype(float))
        y = D["PSI_scalar"].values.astype(float)
        # Too few points, or a flat PSI curve, cannot support a two-segment fit.
        if len(D) < 5 or np.all(y == y[0]):
            rows.append({"L_out": L, "s_star": float("nan"), "bic": float("nan"), "n_points": int(len(D))})
            continue
        best = two_segment_bic(x, y, min_seg=2)
        if best is not None and np.isfinite(best["bic"]):
            # best["k"] is used as an index into x; the kink location is
            # mapped back from log(s) to s.
            k_idx = int(best["k"])
            s_star = float(np.exp(x[k_idx]))
            rows.append({"L_out": L, "s_star": s_star, "bic": float(best["bic"]), "n_points": int(len(x))})
        else:
            rows.append({"L_out": L, "s_star": float("nan"), "bic": float("nan"), "n_points": int(len(x))})

    kink = pd.DataFrame(rows)
    kink.to_csv(os.path.join(out2, "kink_fits_optics.csv"), index=False)

    # Cross-container scaling: regress log(s*) on 0.5*log(L) when at least two
    # containers produced a finite kink.
    summary = {}
    valid = kink.dropna()
    if len(valid) >= 2:
        xs = 0.5 * np.log(valid["L_out"].values.astype(float))
        ys = np.log(valid["s_star"].values.astype(float))
        reg = simple_regression(xs, ys)
        summary = {"slope": float(reg["slope"]), "intercept": float(reg["intercept"]),
                   "R2": float(reg["R2"]), "n": int(len(valid))}
        with open(os.path.join(out2, "scaling_optics.json"), "w") as f:
            json.dump(summary, f, indent=2)

    # NOTE(review): failed fits carry float NaN, which json.dump emits as the
    # non-standard token `NaN`; strict JSON parsers will reject this file.
    with open(os.path.join("out", "calibration_space_hinge.json"), "w") as f:
        json.dump({"UGM_ratio": None, "w_px": w, "kink": rows, "scaling": summary}, f, indent=2)

    write_md(os.path.join("out", "RESULTS_UGM.md"), "# STEP 2 — UGM (shared-neck + sqrt-window)\n" + kink.to_string(index=False) + "\n")
    print("UGM step complete.")


if __name__ == "__main__":
    # Command-line entry point: load the config named by --config and run the step.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/diag_ugm.yaml")
    cli_args = parser.parse_args()
    main(load_cfg(cli_args.config))
